import numpy as np
import csv, os, string, pickle
from tqdm import tqdm
from sklearn.compose import ColumnTransformer
import training
import train_word2vec
import torch
import json
from datetime import datetime
import time
from datetime import timedelta
import pandas as pd
import numpy as np
import scipy.stats as st
import argparse
class Arguments():
    """Mutable container for the settings forwarded by ``train_model``.

    All fields are plain attributes initialised to the defaults below;
    callers overwrite them directly after construction.
    """

    def __init__(self):
        # Locations on disk and checkpoint handling.
        self.data_path = self.results_folder = None
        self.temp_folder = self.docs_folder = None
        self.pretrained_vectors = False
        self.prev_model_file = None

        # Optimisation settings.
        self.opt_type, self.lr = "rms", 0.0001
        self.nepochs, self.dropout_p = 50, 0.5
        self.nfolds, self.resample = 1, 0
        self.w_decay = 0

        # Network architecture.
        self.model_id = None
        self.n_layers, self.embed_dim = 3, 300
        self.c_dim, self.h_dim = 100, 256
        self.representation_dim = 4096

        # "Predict n words" parameter.
        self.t = 1

        # Arguments for the synthetic experiment.
        self.synth_args = None

class SupArgs():
    """Result record of one supervised (classifier) run."""

    def __init__(self,best_C,best_solver,accu,cv_duration):
        # Selected hyper-parameters.
        self.best_C, self.best_solver = best_C, best_solver
        # Achieved accuracy and the duration of the cross-validation.
        self.accu, self.cv_duration = accu, cv_duration
 

def train_model(args):
    """Forward the fields of ``args`` to ``training.train_model``.

    Args:
        args: Object exposing the attributes read below (an ``Arguments``
            instance in this file).

    Returns:
        Whatever ``training.train_model`` returns.
    """
    # Keyword arguments whose name matches the attribute on ``args``.
    same_name = (
        "model_id", "data_path", "results_folder", "t", "c_dim", "h_dim",
        "representation_dim", "nepochs", "lr", "embed_dim", "opt_type",
        "dropout_p", "n_layers", "nfolds", "resample", "w_decay",
        "pretrained_vectors",
    )
    call_kwargs = {name: getattr(args, name) for name in same_name}

    # Keyword arguments that are named differently on the training side.
    call_kwargs["temp_model_folder"] = args.temp_folder
    call_kwargs["prev_model_file"] = args.prev_model_file
    call_kwargs["presampled_docs_file"] = args.docs_folder
    call_kwargs["synthetic_args"] = args.synth_args

    return training.train_model(**call_kwargs)

'''
Helper functions
'''
def saveClassifierDataframe(df,folder_name,file_name):
    """Pickle ``df`` to ``folder_name/file_name`` and return that path.

    The folder must already exist. ``df`` may be any picklable object.
    """
    destination = os.path.join(folder_name, file_name)
    with open(destination, 'wb') as handle:
        pickle.Pickler(handle).dump(df)
    print("Dataframe saved...")
    return destination


def saveModel(model,path):
    """Serialise ``model`` to ``path`` with ``torch.save`` and report the location."""
    torch.save(model, path)
    print(f"model save at path: {path}")


def saveModelStats(args,dataset,unsup_id,unsup_duration=None,sup_duration=None,acc=None,best_c=None,
                    best_solver=None,cv_duration=None,df_filepath=None,best_model_file_name=None):
    """Record one experiment's settings and results in the per-dataset metadata file.

    The entry is stored under the key ``unsup_id`` in
    ``models/models_<dataset>/meta_model.json`` (relative to the current
    working directory). An existing entry with the same key is replaced. The
    folder and the file are created on first use, so no manual setup is needed.

    Args:
        args: Object exposing ``model_id``, ``n_layers``, ``embed_dim``,
            ``h_dim``, ``representation_dim``, ``nepochs``, ``lr``,
            ``dropout_p``, ``resample``, ``opt_type``, ``t`` and ``w_decay``.
        dataset: Dataset name used to build the folder name.
        unsup_id: Key of the entry inside the JSON object.
        unsup_duration, sup_duration, acc, best_c, best_solver, cv_duration,
        df_filepath, best_model_file_name: Optional results. Each one is
            written only when it is not ``None``, so legitimate falsy values
            such as an accuracy of ``0.0`` are kept.

    Raises:
        json.JSONDecodeError: If the existing file holds invalid JSON. The file
            is left untouched in that case rather than being overwritten.
    """
    folder_name = 'models/models_'+dataset
    os.makedirs(folder_name, exist_ok=True)
    meta_path = os.path.join(folder_name, 'meta_model.json')

    # Settings that are always recorded.
    cur = {
        'model_id': args.model_id,
        "time saved": datetime.now().strftime("%d/%m/%Y %H:%M:%S"),
        "layers": args.n_layers,
        "embed_dim": args.embed_dim,                    # embedding layer dimension
        "h_dim": args.h_dim,                            # hidden layer dimension
        "representation_dim": args.representation_dim,  # representation dimension
        "nepochs": args.nepochs,                        # number of epochs
        "lr": args.lr,                                  # learning rate
        "dropout_p": args.dropout_p,                    # dropout
        'resample': args.resample,                      # frequency of resampling data
        "opt_type": args.opt_type,                      # optimizer
        "t": args.t,                                    # number of words in labels
        "w_decay": args.w_decay,                        # weight decay
    }

    # Results that are recorded only when the caller supplied them.
    optional = {
        "unsup_duration": unsup_duration,
        "sup_duration": sup_duration,
        'accuracy': acc,
        'best_c': best_c,
        'best_solver': best_solver,
        'cv_duration': cv_duration,
        'df_filepath': df_filepath,
        'best_model_file_name': best_model_file_name,
    }
    cur.update({key: value for key, value in optional.items() if value is not None})

    # Start from an empty object when the file is missing or blank.
    try:
        with open(meta_path, 'r') as openfile:
            content = openfile.read()
    except FileNotFoundError:
        content = ""
    obj = json.loads(content) if content.strip() else {}

    obj[unsup_id] = cur
    # Serialise before opening for writing so that a serialisation error
    # cannot truncate the existing file.
    updated_obj = json.dumps(obj, indent=4)
    with open(meta_path, "w") as outfile:
        outfile.write(updated_obj)



# Supervised training-set sizes evaluated for each dataset.
D_SAMPLE_ARR = dict(
    agnews=[100, 600, 1100, 1600, 2100, 2600, 3100, 3600, 4000],
    imdb=[50, 250, 500, 700, 900, 1100, 1300, 1500, 1700, 2000],
    dbpedia=[400, 1900, 3400, 4900, 6400, 7900, 9400, 10900, 12400, 14000],
    smalldbpedia=[100, 600, 1100, 1600, 2100, 2600, 3100, 3600, 4000],
)
def generateLCDataframe(dataset='agnews',predictmodel=None,embedding_type='model',contrastive=False,word2vec_emb_size=300,lda_emb_size=50,custom_embedding= False,N=10):
    """Build a learning-curve dataframe: accuracy versus supervised sample count.

    For every sample size in ``D_SAMPLE_ARR[dataset]`` the classifier is
    trained ``N`` times via ``training.train_classifier``. The mean accuracy
    and a 95% Student-t confidence interval over the repetitions are recorded.

    Args:
        dataset: Key into ``D_SAMPLE_ARR``; also forwarded to the classifier.
        predictmodel: Forwarded as ``model``.
        embedding_type: Forwarded as is. The original notes list 'model',
            'BOW', 'word2vec' and 'lda' as available types.
        contrastive, word2vec_emb_size, lda_emb_size, custom_embedding:
            Forwarded unchanged to ``training.train_classifier``.
        N: Number of repetitions per sample size. With fewer than two
            repetitions the interval bounds are NaN.

    Returns:
        ``(df, sup_duration)``: ``df`` has one row per sample size with the
        columns ``n_samples``, ``mean_accuracy``, ``lb``, ``ub``,
        ``best_c_list``, ``best_solver_list`` and ``cv_duration_list``.
        ``sup_duration`` is the elapsed wall-clock time as a string.
    """
    print("Running supervised learning...")
    sup_starttime=time.time()

    data=[]
    for n_samples in tqdm(D_SAMPLE_ARR[dataset]):
        accu_list=[]
        best_c_list=[]
        best_solver_list=[]
        cv_duration_list=[]
        for n in tqdm(range(N)):
            print('Running reptition number: ',n,' with number of data used: ',n_samples)

            best_c,best_solver,accu,cv_duration=training.train_classifier(
                dataset=dataset,model=predictmodel,
                n_samples=n_samples,embedding_type=embedding_type,contrastive=contrastive,
                word2vec_emb_size=word2vec_emb_size,lda_emb_size=lda_emb_size,
                custom_embedding=custom_embedding)

            accu_list.append(accu)
            best_c_list.append(best_c)
            best_solver_list.append(best_solver)
            cv_duration_list.append(cv_duration)

        mean=np.mean(accu_list)
        # The confidence level is passed positionally: the keyword was renamed
        # from ``alpha`` to ``confidence`` in SciPy, and the positional form
        # works with both old and new releases.
        lb,ub=st.t.interval(0.95, df=len(accu_list)-1, loc=mean, scale=st.sem(accu_list))
        data.append(
            {
                'n_samples':n_samples,
                'mean_accuracy':mean,
                'lb':lb,
                'ub':ub,
                'best_c_list':best_c_list,
                'best_solver_list':best_solver_list,
                'cv_duration_list':cv_duration_list
            }
        )

    df = pd.DataFrame(data)
    sup_duration=str(timedelta(seconds=time.time() - sup_starttime))
    return df,sup_duration
    

def runExperiment_model(dataset = "agnews",run_type="one",skip_unsup=False,unsup_id=None,model_id=None,n_layers=None,
                        representation_dim=None,embed_dim=5000,c_dim = 100,h_dim=None,nepochs=150,opt_type='amsgrad',w_decay=0,
                        resample=2,dropout_p=0,lr=0.0002,prev_model_file=None,n_samples=4000,custom_embedding = False):
    """Train (or load) a representation model, then evaluate it with a linear classifier.

    Args:
        dataset: "agnews" | "imdb" | "dbpedia" | "smalldbpedia".
        run_type: "one" runs a single classifier fit on ``n_samples`` samples;
            "df" runs the full learning-curve sweep and pickles the dataframe.
            Any other value only prints a message.
        skip_unsup: When true, skip unsupervised training and load the model
            from ``prev_model_file`` instead.
        unsup_id: Identifier used in result/model file names and as the key of
            the metadata entry.
        model_id: Forwarded to training; the value "contrastive" additionally
            switches the classifier to contrastive mode.
        n_layers, representation_dim, embed_dim, c_dim, h_dim, nepochs,
        opt_type, w_decay, resample, dropout_p, lr: Copied onto the
            ``Arguments`` object handed to ``train_model``.
        prev_model_file: Path passed to ``torch.load`` when ``skip_unsup`` is
            true.
        n_samples: Sample count for the "one" run.
        custom_embedding: Forwarded to the "df" run only.

    Returns:
        The accuracy of the quick run when ``run_type == "one"``, otherwise
        ``None``.
    """
    print(unsup_id)

    ## Fit the model
    args = Arguments()
    # ``makedirs`` also creates missing parent folders ("models", "data"),
    # which a plain ``mkdir`` on these nested paths would fail on.
    args.temp_folder = f'models/models_{dataset}' ## Temporary folder to hold intermediate models
    os.makedirs(args.temp_folder, exist_ok=True)

    args.data_path = f'data/data_{dataset}' ## Folder for experiment data
    os.makedirs(args.data_path, exist_ok=True)

    ## Folder to hold results
    args.results_folder = f"results/results_{dataset}/{unsup_id}"
    os.makedirs(args.results_folder, exist_ok=True)

    ## NN parameters
    args.model_id = model_id
    args.n_layers = n_layers
    args.c_dim = c_dim                            ## final layer dimension
    args.embed_dim = embed_dim                    ## embedding layer dimension
    args.h_dim = h_dim                            ## hidden layer dimension
    args.representation_dim = representation_dim

    args.nepochs = nepochs       ## number of epochs
    args.lr = lr                 ## learning rate
    args.dropout_p = dropout_p   ## dropout
    args.resample = resample     ## frequency of resampling data
    args.opt_type = opt_type     ## optimizer
    args.t = 4                   ## number of words in labels (fixed for this experiment)
    args.w_decay = w_decay       ## weight decay

    ## Run unsupervised
    print('Rerun unsupervised learning....')
    unsup_starttime=time.time()

    model = None
    model_file = None  # path of the model behind this run, recorded in the stats
    if not skip_unsup:
        model = train_model(args)
        model_file = f'models/models_{dataset}/{unsup_id}_model.pt'
        saveModel(model, model_file)
    elif prev_model_file:
        # NOTE(review): torch.load unpickles the file; only load trusted
        # checkpoints. Recent PyTorch releases may need ``weights_only=False``
        # to load a whole pickled model -- confirm against the pinned version.
        model = torch.load(prev_model_file)
        model_file = prev_model_file
    else:
        print("No model exception. You should pass a previous model file if you want to skip unsupervised learning")
    unsup_duration=str(timedelta(seconds=time.time() - unsup_starttime))

    # Run linear classifier
    print("Running supervised learning...")
    contrastive = model_id == "contrastive"

    def generateDfRun(model,model_id,custom_embedding=False):
        """Run the learning-curve sweep, pickle the dataframe and record the stats."""
        df,duration = generateLCDataframe(dataset=dataset,predictmodel=model,contrastive=contrastive,custom_embedding=custom_embedding)

        print('Saving dataframe and model stats...')
        file_path=saveClassifierDataframe(df,args.results_folder,f'{unsup_id}_accuracy_new_df.pkl')

        saveModelStats(args,dataset,unsup_id,unsup_duration=unsup_duration,sup_duration=duration,
                       df_filepath=file_path,best_model_file_name=model_file)
        print("Stats Saved") #in meta_model.json

    def quickTestRun(model,n_samples,model_id):
        """Fit the classifier once, record the stats and return the accuracy."""
        # ``dataset`` is forwarded so the classifier runs on the same dataset
        # the stats are recorded for.
        # NOTE(review): ``custom_embedding`` is not forwarded here, unlike in
        # the "df" run -- confirm whether that is intended.
        best_c,best_solver,accu,cv_duration=training.train_classifier(dataset=dataset,model=model,n_samples=n_samples,
                                                                      embedding_type='model',contrastive=contrastive)
        saveModelStats(args,dataset,unsup_id,acc=accu,unsup_duration=unsup_duration,best_c=best_c,
                       best_solver=best_solver,cv_duration=cv_duration)
        return accu

    if run_type=="one":
        print(f"running one time on {n_samples} sample")
        return quickTestRun(model,n_samples,model_id)
    if run_type=="df":
        print("running df")
        generateDfRun(model,model_id,custom_embedding = custom_embedding)
    else:
        print("Wrong run_type specified")
    return None
    # accu=quickTestRun(model,4000,model_id)
    # return accu


def runExperiment_baseline(dataset="agnews",run_type="one",unsup_id=None,embedding_type='BOW',n_samples=4000,\
    word2vec_emb_size=300,lda_emb_size = 50,random = False):
    """Evaluate a baseline embedding with the linear classifier.

    Args:
        dataset: Dataset name; used for the folder names and forwarded to the
            classifier.
        run_type: "one" fits the classifier once on ``n_samples`` samples; any
            other value runs the full learning-curve sweep.
        unsup_id: Identifier used in the results folder and file names.
        embedding_type: Embedding forwarded to the classifier.
        n_samples: Sample count for the "one" run.
        word2vec_emb_size, lda_emb_size: Forwarded to the classifier.
        random: Forwarded to ``training.train_classifier`` in the "one" run
            only; the sweep does not use it.

    Returns:
        The accuracy for ``run_type == "one"``, otherwise ``None``.
    """
    print(unsup_id)
    args = Arguments()

    # ``makedirs`` also creates missing parent folders ("models", "data",
    # "results"), which a plain ``mkdir`` on these nested paths would fail on.
    args.temp_folder = f"models/models_{dataset}" ## Temporary folder to hold intermediate models
    os.makedirs(args.temp_folder, exist_ok=True)

    args.data_path = f'data/data_{dataset}' ## Folder for experiment data
    os.makedirs(args.data_path, exist_ok=True)

    ## Folder to hold results
    args.results_folder = f"results/results_{dataset}/{unsup_id}"
    os.makedirs(args.results_folder, exist_ok=True)

    def generateDfRun():
        """Run the learning-curve sweep and pickle the resulting dataframe."""
        df,duration=generateLCDataframe(dataset=dataset,embedding_type=embedding_type,word2vec_emb_size=word2vec_emb_size,lda_emb_size=lda_emb_size)
        print('supervised duaration: ',duration)
        saveClassifierDataframe(df,args.results_folder,f'{unsup_id}_accuracydf.pkl')

    def quickTestRun(n_samples):
        """Fit the classifier once and return its accuracy."""
        _,_,accu,_=training.train_classifier(dataset=dataset,n_samples=n_samples,embedding_type=embedding_type,\
            word2vec_emb_size=word2vec_emb_size,random = random,lda_emb_size=lda_emb_size)
        print(f'Accuracy is {accu}')
        return accu

    if run_type=='one':
        return quickTestRun(n_samples)

    generateDfRun()
    return None

    #Set number of one-time training sample here
    # accu=quickTestRun(4000)
    # print("Finished run of id:",unsup_id," with embedding_type: ",embedding_type," with accuracy: ",accu)

def generateCI_word2vec_emb_size_tuning(dataset,id,emb_size_array,N=10,n_samples = 2000):
    """Measure classifier accuracy for several word2vec embedding sizes.

    For each size in ``emb_size_array`` a word2vec matrix is trained ``N``
    times with distinct seeds and a classifier is fitted on ``n_samples``
    samples each time. The mean accuracy and a 95% Student-t confidence
    interval are recorded. The dataframe is pickled to
    ``results/results_<dataset>/<id>/<id>_accuracydf.pkl``.

    Args:
        dataset: Dataset name; used for folder names and forwarded to the
            training helpers.
        id: Run identifier used in the results folder and file name.
        emb_size_array: Iterable of embedding sizes to evaluate.
        N: Repetitions per embedding size. With fewer than two repetitions the
            interval bounds are NaN.
        n_samples: Number of supervised samples per classifier fit.

    Returns:
        ``(df, sup_duration)``: ``df`` has the columns
        ``embedding_dimension``, ``mean_accuracy``, ``lb`` and ``ub``;
        ``sup_duration`` is the elapsed wall-clock time as a string.
    """
    args = Arguments()

    # ``makedirs`` also creates missing parent folders ("models", "data",
    # "results"), which a plain ``mkdir`` on these nested paths would fail on.
    args.temp_folder = f"models/models_{dataset}" ## Temporary folder to hold intermediate models
    os.makedirs(args.temp_folder, exist_ok=True)

    args.data_path = f'data/data_{dataset}' ## Folder for experiment data
    os.makedirs(args.data_path, exist_ok=True)

    ## Folder to hold results
    args.results_folder = f"results/results_{dataset}/{id}"
    os.makedirs(args.results_folder, exist_ok=True)

    print("Running CI word2vec...")
    sup_starttime=time.time()

    data=[]
    for i,embedding_size in enumerate(emb_size_array):
        accu_list=[]
        for n in range(N):
            print('Running reptition number: ',n,' with emb_size ',embedding_size)
            # A distinct seed for every (embedding size, repetition) pair.
            seed = i*N + n
            word2vec_matrix = train_word2vec.train_word2vec_unsupervised(dataset=dataset,embedding_size=embedding_size,window=5,workers=1,seed = seed)
            _,_,accu,_=training.train_classifier(dataset = dataset,embedding_type='word2vec',word2vec_emb_size=embedding_size,
                                                 word2vec_matrix=word2vec_matrix,n_samples=n_samples)
            accu_list.append(accu)
        print(accu_list) #debug
        mean=np.mean(accu_list)
        # The confidence level is passed positionally: the keyword was renamed
        # from ``alpha`` to ``confidence`` in SciPy, and the positional form
        # works with both old and new releases.
        lb,ub=st.t.interval(0.95, df=len(accu_list)-1, loc=mean, scale=st.sem(accu_list))
        data.append(
            {
                'embedding_dimension':embedding_size,
                'mean_accuracy':mean,
                'lb':lb,
                'ub':ub,
            }
        )

    df = pd.DataFrame(data)
    sup_duration=str(timedelta(seconds=time.time() - sup_starttime))

    saveClassifierDataframe(df,args.results_folder,f'{id}_accuracydf.pkl')
    print(f'total time used: {sup_duration}')
    return df,sup_duration




if __name__=='__main__':
    '''
    REMEMBER TO CHANGE THE BATCH SIZE ACROSS RUNS in the training.py file
    '''
    